import torch
import torch.nn as nn
import torch.nn.functional as F

'''
    by ChatGPT
'''

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys, and values into ``num_heads`` subspaces of size
    ``d_model // num_heads``, attends within each subspace independently,
    then concatenates the heads and applies a final output projection.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // self.num_heads
        # Pre-compute the attention scale once as a Python float.  The old
        # code built torch.sqrt(torch.tensor(depth, dtype=torch.float32))
        # on every forward call, allocating a tensor per call and hard-coding
        # float32 (which up-casts fp16/fp64 scores via type promotion).
        self.scale = self.depth ** -0.5

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Compute multi-head attention.

        Args:
            query, key, value: tensors of shape (batch, seq_len, d_model).
            mask: optional tensor broadcastable to (batch, num_heads,
                q_len, k_len); positions where ``mask == 0`` are excluded
                from attention.  Defaults to None (original behavior).

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)

        # Linear projections.
        query = self.W_q(query)
        key = self.W_k(key)
        value = self.W_v(value)

        # Split into heads: (batch, num_heads, seq_len, depth).
        query = query.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Blocked positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, value)

        # Merge heads back: (batch, seq_len, d_model).
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final output projection.
        return self.W_o(attention_output)

class PositionwiseFeedforward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Expands the model dimension to ``d_ff``, applies a ReLU, and projects
    back to ``d_model``; applied identically at every sequence position.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names kept as linear1/linear2 so state_dict keys match.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Expand to the inner dimension, gate through the non-linearity,
        # then project back down to the model dimension.
        hidden = self.linear1(x)
        activated = F.relu(hidden)
        return self.linear2(activated)

class TransformerLayer(nn.Module):
    """A single post-norm Transformer encoder layer.

    Self-attention and a position-wise feed-forward network, each wrapped
    in a residual connection followed by layer normalization
    (``LayerNorm(x + sublayer(x))``).
    """

    def __init__(self, d_model, num_heads, d_ff):
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedforward(d_model, d_ff)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape.
        """
        # Self-attention sub-layer: residual add, then LayerNorm.
        # (The original assigned an unused local `residual = x` here;
        # the residual is taken directly in the expression below.)
        x = self.layer_norm1(x + self.self_attn(x, x, x))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        x = self.layer_norm2(x + self.ffn(x))

        return x

class Transformer(nn.Module):
    """A stack of ``num_layers`` identical TransformerLayer blocks.

    The input is passed through each layer in order; every layer maps
    (batch, seq_len, d_model) -> (batch, seq_len, d_model).
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super(Transformer, self).__init__()
        blocks = [TransformerLayer(d_model, num_heads, d_ff) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # Thread the activation through the stack, one layer at a time.
        out = x
        for block in self.layers:
            out = block(out)
        return out

# Example usage — guarded so that importing this module does not build a
# 6-layer model and run a forward pass as an import side effect.
if __name__ == "__main__":
    batch_size = 64
    seq_length = 20
    embedding_size = 256  # d_model
    num_heads = 8
    d_ff = 512
    num_layers = 6

    # Random input tensor: (batch, seq_len, d_model).
    inputs = torch.randn(batch_size, seq_length, embedding_size)

    # Instantiate the Transformer model.
    transformer = Transformer(embedding_size, num_heads, d_ff, num_layers)

    # Forward pass; prints torch.Size([64, 20, 256]).
    outputs = transformer(inputs)
    print(outputs.size())  # Output tensor shape: [batch_size, seq_length, embedding_size]